import numpy as np
import matplotlib.pyplot as plt
import json
import statistics

# Number of independent trials run for every configuration.
NUM_TRIALS = 5

# Experiment parameters (each string is spliced verbatim into result file names).
datasets = [
	"kdda",
	"rcv1_test.binary",
	"url_combined",
]
lambdas = ["1e-3", "1e-4", "1e-5"]  # l2reg values
memory_bounds = ["2KB", "4KB", "8KB", "16KB", "32KB"]

# Algorithm names as they appear in result folder/file names.
jl_alg_string = "jl_recovery_sketch"
logistic_sketch_alg_string = "logistic_sketch"
black_box_alg_string = "black_box_reduction"

# Values of k (number of recovered top weights) used as the x-axis of the plots.
list_of_k = list(range(0, 121, 20))

# Helpers that build the result directory, file, and path strings.
def results_folder(algorithm, trial_num):
	"""Return the directory name that holds one trial's JSON results.

	Example: results_folder("logistic_sketch", 2) -> "JSON_Results_logistic_sketch_Trial_2".
	"""
	fmt_args = (algorithm, trial_num)
	return "JSON_Results_%s_Trial_%d" % fmt_args


def results_file_string(algorithm, trial_num, memory_bound, l2reg, dataset):
	"""Return the JSON result file name for one (algorithm, trial, memory, l2reg, dataset) run."""
	name_parts = (algorithm, trial_num, memory_bound, l2reg, dataset)
	return "%s_Trial_%d_%s_l2reg_%s_Dataset_%s" % name_parts


def results_path_string(algorithm, trial_num, memory_bound, l2reg, dataset):
	"""Return the relative path "<results_folder>/<results_file_string>" for one run."""
	return "%s/%s" % (
		results_folder(algorithm=algorithm, trial_num=trial_num),
		results_file_string(
			algorithm=algorithm,
			trial_num=trial_num,
			memory_bound=memory_bound,
			l2reg=l2reg,
			dataset=dataset,
		),
	)

# Create weight recovery plots:
# For each dataset/l2reg/memory configuration, plot the
# relative error metric of top-k logistic regression
# weight recovery for the JL sketch, logistic_sketch, and
# black-box reduction. Each curve shows the median relative
# error across trials, with error bars spanning the min and max.


# Computes the relative error weight recovery metric.
# D = 512 was used in main paper. D = "full_weights" 
# compares to full weight vector. Assumes k << D.
def compute_weight_recovery_error(true_indices, true_weights, est_indices, est_weights, k, D):
	# Compute the top k weights of est_weights.
	# Note that weights are sorted in decreasing order.
	true_diff_dict = {}
	est_diff_dict = {}

	if (D == "full_weights"):
		num_indices_compared = len(true_indices)
	else:
		num_indices_compared = 512

	for i in range(num_indices_compared):
		idx = true_indices[i]
		weight = true_weights[i]
		true_diff_dict[idx] = weight
		est_diff_dict[idx] = weight

	for i in range(k):
		true_idx = true_indices[i]
		true_weight = true_weights[i]
		true_diff_dict[true_idx] -= true_weight

		est_idx = est_indices[i]
		est_weight = est_weights[i]
		if (est_idx in est_diff_dict):
			est_diff_dict[est_idx] -= est_weight
		else:
			est_diff_dict[est_idx] = -1 * est_weight

	true_error = 0
	for idx in true_diff_dict:
		true_error += true_diff_dict[idx] ** 2

	est_error = 0
	for idx in est_diff_dict:
		est_error += est_diff_dict[idx] ** 2

	return est_error/true_error

def _load_full_logistic_top_weights(dataset, l2reg):
	"""Return (top_indices, top_weights) from the full logistic regression results for this config."""
	path = "JSON_Results_logistic_regression_full/logistic_regression_full_l2reg_%s_Dataset_%s" % (l2reg, dataset)
	with open(path) as json_file:
		results = json.load(json_file)["results"]
	return results["top_indices"], results["top_weights"]


def _load_trial_results(algorithm, dataset, l2reg, memory_bound):
	"""Return the "results" dict of trials 1..NUM_TRIALS for one algorithm and config."""
	trial_results = []
	for trial_num in range(1, NUM_TRIALS + 1):
		path = results_path_string(algorithm=algorithm, trial_num=trial_num, memory_bound=memory_bound, l2reg=l2reg, dataset=dataset)
		with open(path) as json_file:
			trial_results.append(json.load(json_file)["results"])
	return trial_results


def _error_stats_per_k(trial_results, true_indices, true_weights, D):
	"""For each k in list_of_k, return the median, min and max relative error across trials as three lists."""
	medians, mins, maxes = [], [], []
	for k in list_of_k:
		errors = [
			compute_weight_recovery_error(
				true_indices=true_indices,
				true_weights=true_weights,
				est_indices=result["top_indices"],
				est_weights=result["top_weights"],
				k=k,
				D=D,
			)
			for result in trial_results
		]
		medians.append(statistics.median(errors))
		mins.append(min(errors))
		maxes.append(max(errors))
	return medians, mins, maxes


def _plot_median_with_range(ax, medians, mins, maxes, color, label):
	"""Plot medians against list_of_k, with asymmetric error bars spanning [min, max]."""
	median_arr = np.asarray(medians, dtype=float)
	# errorbar expects non-negative distances below/above each point, stacked as shape (2, N).
	yerr = np.abs(np.vstack([np.asarray(mins, dtype=float), np.asarray(maxes, dtype=float)]) - median_arr)
	ax.errorbar(x=list_of_k, y=median_arr, yerr=yerr, color=color, label=label)


def weight_recovery_plot(dataset, l2reg, memory_bound, D):
	"""Plot and save top-k weight recovery error for all sketching methods on one config.

	Reads the full logistic regression top weights and the per-trial results of
	the JL sketch, logistic sketch and black-box reduction. Computes the median,
	min and max relative error (see compute_weight_recovery_error) for each k in
	list_of_k. Saves one labelled plot to "<dataset>_<l2reg>_<memory_bound>_<D>.png".

	Args:
		dataset, l2reg, memory_bound: config strings spliced into the result file paths.
		D: "full_weights" or a numeric string such as "512"; passed to
			compute_weight_recovery_error and shown in the title and file name.

	Raises:
		FileNotFoundError: if any required JSON results file is missing.
	"""
	true_indices, true_weights = _load_full_logistic_top_weights(dataset, l2reg)

	# (algorithm name in result paths, legend label, line color)
	plotted_algorithms = (
		(jl_alg_string, "JL Sketch", "blue"),
		(logistic_sketch_alg_string, "Logistic Sketch", "red"),
		(black_box_alg_string, "Black Box Reduction", "green"),
	)

	# Load and aggregate everything first so a missing file aborts before any figure exists.
	series = []
	for algorithm, label, color in plotted_algorithms:
		trial_results = _load_trial_results(algorithm, dataset, l2reg, memory_bound)
		series.append((label, color, _error_stats_per_k(trial_results, true_indices, true_weights, D)))

	fig, ax = plt.subplots()
	try:
		for label, color, (medians, mins, maxes) in series:
			_plot_median_with_range(ax, medians, mins, maxes, color=color, label=label)
		ax.set_title("Dataset: %s; Lambda: %s, Memory Bound: %s, D: %s" % (dataset, l2reg, memory_bound, D))
		ax.set_xlabel("K")
		ax.set_ylabel("Relative error")
		# Without a legend the three colored curves cannot be told apart.
		ax.legend()
		fig.savefig("%s_%s_%s_%s.png" % (dataset, l2reg, memory_bound, D))
	finally:
		# Always release the figure, even if plotting or saving fails.
		plt.close(fig)

# Generate one plot per (dataset, lambda, memory bound, D) configuration.
# Only D = "512" is plotted here.
configs = (
	(dataset, l2reg, memory_bound, D)
	for dataset in datasets
	for l2reg in lambdas
	for memory_bound in memory_bounds
	for D in ["512"]
)
for dataset, l2reg, memory_bound, D in configs:
	print("plot")
	weight_recovery_plot(dataset=dataset, l2reg=l2reg, memory_bound=memory_bound, D=D)
